In [1]:
import os
import shutil
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tensorflow.keras import Sequential
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras.layers import Dense, Flatten, Softmax, Conv2D, Dropout, MaxPooling2D

# Record the TensorFlow version the outputs in this notebook were produced with.
print(str(tf.__version__))

2.3.0


In [2]:
# load_data() returns a pair of (images, labels) tuples: train first, then test.
mnist = tf.keras.datasets.mnist.load_data()
x_train, y_train = mnist[0]
x_test, y_test = mnist[1]

In [3]:
# Image geometry taken from the first training image: shape is (rows, cols).
HEIGHT, WIDTH = x_train[0].shape
# Number of distinct labels, converted to a plain Python int. tf.size returns a
# scalar tf.Tensor, which is awkward to use as an integer hyperparameter
# (to_categorical's num_classes, the width of the final Dense layer).
NCLASSES = int(tf.size(tf.unique(y_train).y))
print("Image height x width is", HEIGHT, "x", WIDTH)
print("There are", NCLASSES, "classes")

Image height x width is 28 x 28
There are 10 classes


In [4]:
def get_model():
    """Build and compile a small CNN classifier for single-channel images.

    The network expects inputs shaped (HEIGHT, WIDTH, 1). It ends in an
    explicit Softmax layer, so it outputs class probabilities over NCLASSES
    classes. It is compiled with categorical cross-entropy, which expects
    one-hot encoded targets.

    Returns:
        A compiled tf.keras Sequential model.
    """
    model = Sequential([
        # Keras (channels_last) expects (rows, cols, channels), i.e.
        # (height, width, channels). Swapping the two dimensions only works
        # for square images.
        Conv2D(64, kernel_size=3, activation='relu',
               input_shape=(HEIGHT, WIDTH, 1)),
        MaxPooling2D(2),
        Conv2D(32, kernel_size=3, activation='relu'),
        MaxPooling2D(2),
        Flatten(),
        Dense(400, activation='relu'),
        Dense(100, activation='relu'),
        Dropout(.25),
        # One unit per class, derived from the data rather than hard-coded.
        # int() accepts NCLASSES whether it is a Python int or a scalar tensor.
        Dense(int(NCLASSES)),
        Softmax(),
    ])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model

In [5]:
# Input-pipeline settings used by load_dataset: the buffer size passed to
# Dataset.shuffle, and the batch size passed to Dataset.batch.
BUFFER_SIZE, BATCH_SIZE = 5000, 100

def scale(image, label):
    """Normalize one (image, label) pair for the model.

    The image is cast to float32, divided by 255, and given a trailing channel
    axis. For 0-255 pixel values this maps pixels into [0, 1]. The label is
    passed through unchanged.
    """
    normalized = tf.expand_dims(tf.cast(image, tf.float32) / 255, axis=-1)
    return normalized, label


def load_dataset(training=True):
    """Loads MNIST dataset into a tf.data.Dataset.

    Args:
        training: If True, use the training split and return an endlessly
            repeating dataset whose examples are reshuffled on every pass.
            If False, use the test split and return a single ordered pass.

    Returns:
        A dataset of (image, label) batches of size BATCH_SIZE. Images are
        float32 with a trailing channel axis (see `scale`); labels are one-hot
        vectors of length NCLASSES.
    """
    (x_train, y_train), (x_test, y_test) = mnist
    x, y = (x_train, y_train) if training else (x_test, y_test)
    # One-hot encode the classes. int() accepts NCLASSES whether it is a
    # Python int or a scalar tensor.
    y = tf.keras.utils.to_categorical(y, int(NCLASSES))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    if training:
        # Shuffle individual examples *before* batching. Shuffling after
        # batch() only permutes whole batches, so every batch would always
        # contain the same consecutive examples.
        dataset = dataset.shuffle(BUFFER_SIZE)
    dataset = dataset.map(scale).batch(BATCH_SIZE)
    return dataset.repeat() if training else dataset

In [6]:
NUM_EPOCHS = 10
STEPS_PER_EPOCH = 100

model = get_model()
train_data = load_dataset()
validation_data = load_dataset(training=False)

OUTDIR = "mnist_digits/"
# Start every run from an empty output directory. Otherwise, re-running this
# cell leaves event files and checkpoints from earlier executions in OUTDIR,
# and TensorBoard mixes them with the new run.
shutil.rmtree(OUTDIR, ignore_errors=True)
checkpoint_callback = ModelCheckpoint(
    OUTDIR, save_weights_only=True, verbose=1)
tensorboard_callback = TensorBoard(log_dir=OUTDIR)

# Wall-clock timing of the whole fit, including per-epoch validation.
t1 = time.perf_counter()
history = model.fit(
    train_data,
    validation_data=validation_data,
    epochs=NUM_EPOCHS,
    # train_data repeats forever, so the epoch length must be given explicitly.
    steps_per_epoch=STEPS_PER_EPOCH,
    verbose=2,
    callbacks=[checkpoint_callback, tensorboard_callback]
)
t2 = time.perf_counter()
print("training took: {:4.4f} secs.".format(t2 - t1))

Epoch 1/10
Instructions for updating:
use `tf.profiler.experimental.stop` instead.

Epoch 00001: saving model to mnist_digits/
100/100 - 38s - loss: 0.6817 - accuracy: 0.7822 - val_loss: 0.1847 - val_accuracy: 0.9444
Epoch 2/10

Epoch 00002: saving model to mnist_digits/
100/100 - 37s - loss: 0.1862 - accuracy: 0.9419 - val_loss: 0.1023 - val_accuracy: 0.9672
Epoch 3/10

Epoch 00003: saving model to mnist_digits/
100/100 - 37s - loss: 0.1381 - accuracy: 0.9597 - val_loss: 0.0983 - val_accuracy: 0.9671
Epoch 4/10

Epoch 00004: saving model to mnist_digits/
100/100 - 37s - loss: 0.1012 - accuracy: 0.9707 - val_loss: 0.0602 - val_accuracy: 0.9802
Epoch 5/10

Epoch 00005: saving model to mnist_digits/
100/100 - 37s - loss: 0.0755 - accuracy: 0.9770 - val_loss: 0.0698 - val_accuracy: 0.9776
Epoch 6/10

Epoch 00006: saving model to mnist_digits/
100/100 - 37s - loss: 0.0782 - accuracy: 0.9766 - val_loss: 0.0413 - val_accuracy: 0.9869
Epoch 7/10

Epoch 00007: saving model to mnist_digits/
100

In [7]:
# Interactive check of the trained model: draw a digit on the canvas and the
# kernel's `model` predicts it. The 140x140 canvas is sampled onto a 28x28
# grid (scale factor 0.2), which is reshaped to (HEIGHT, WIDTH) in the kernel.
input_form = """
<table>
<td style="border-style: none;">
<div style="border: solid 2px #666; width: 143px; height: 144px;">
<canvas width="140" height="140"></canvas>
</div></td>
<td style="border-style: none;">
<button onclick="clear_value()">Clear</button>
</td>
<td>
Value:
</td>
<td>
<h1><span id="predicted">-</span></h1>
</td>
</table>
"""

javascript = '''
<script type="text/Javascript">
    // Binary 28x28 grid (row-major) mirroring the strokes drawn on the canvas.
    var pixels = [];
    for (var i = 0; i < 28*28; i++) pixels[i] = 0;
    // 1 while a stroke has been drawn but not yet submitted for prediction.
    var click = 0;

    var canvas = document.querySelector("canvas");
    var predicted = document.querySelector("#predicted");

    // Submit the drawing once, at the end of a stroke.
    function finish_stroke(){
        if (click == 1) set_value();
        click = 0;
    }

    canvas.addEventListener("mousemove", function(e){
        if (e.buttons == 1) {
            click = 1;
            predicted.textContent = "-"
            canvas.getContext("2d").fillStyle = "rgb(0,0,0)";
            canvas.getContext("2d").fillRect(e.offsetX, e.offsetY, 8, 8);
            // Map canvas coordinates onto the 28x28 grid. The 8x8 brush drawn at
            // (offsetX, offsetY) always covers grid cells row..row+1 and
            // col..col+1, so mark exactly those cells (no extra column shift).
            var row = Math.floor(e.offsetY * 0.2);
            var col = Math.floor(e.offsetX * 0.2);
            for (var dr = 0; dr < 2; dr++){
                for (var dc = 0; dc < 2; dc++){
                    if ((row + dr < 28) && (col + dc < 28)){
                        pixels[(col+dc)+(row+dr)*28] = 1;
                    }
                }
            }
        } else {
            finish_stroke();
        }
    });
    // Also submit when the button is released over the canvas, so the
    // prediction does not wait for the next mouse movement.
    canvas.addEventListener("mouseup", finish_stroke);

    // Send the grid to the kernel, run the model on it, and display the result.
    // NOTE(review): relies on the classic Notebook global `IPython` JS object and
    // on kernel globals np, tf, model, HEIGHT, WIDTH; presumably unavailable in
    // JupyterLab - confirm before reusing.
    function set_value(){
        predicted.textContent = ". . ."
        var result = ""
        for (var i = 0; i < 28*28; i++) result += pixels[i] + ","
        var kernel = IPython.notebook.kernel;
        kernel.execute("pred = np.array(["+result+"]).reshape(HEIGHT, WIDTH)");
        kernel.execute("pred = tf.cast(pred, tf.float32)");
        kernel.execute("pred = tf.expand_dims([pred], -1)");
        kernel.execute("pred = model.predict(pred)");
        kernel.execute("pred =  '{} (confidence: {:02.2f}%)'.format(np.argmax(pred), np.max(pred)*100)");
        var callbacks = {
            iopub: {
                output: (data) => {
                    predicted.textContent = data.content.text.trim();
                }
            }
        };
        kernel.execute("print(pred)", callbacks);
    }

    // Reset the canvas to white and clear the pixel grid.
    function clear_value(){
        canvas.getContext("2d").fillStyle = "rgb(255,255,255)";
        canvas.getContext("2d").fillRect(0, 0, 140, 140);
        for (var i = 0; i < 28*28; i++) pixels[i] = 0;
        predicted.textContent = "-"
    }
</script>
'''

from IPython.display import HTML
HTML(input_form + javascript)